#!/usr/bin/env python

# A simple transition path sampling (TPS) script with flexible path length.
# author: Lixin Sun
import numpy as np
from read import read_trj
from write_lammps import run_oneway_commitor, find_basin
import time
from numpy import sqrt
import random
from math import sin, cos, pi
import os
# Shorthand for numpy's vector norm; main() uses it to get per-atom speeds.
norm = np.linalg.norm

def main() -> None:
    """Run one-way-shooting transition path sampling from ``initial.npz``.

    Loads a starting path (arrays ``pos``, ``vel``, ``pe``, ``alle``,
    ``intc``, ``label``). It then performs ``replica`` shooting trials. Each
    trial does the following:

    * picks a shooting frame near the current shooting point (``find_sp``);
    * runs one trajectory from it with ``run_oneway_commitor``;
    * accepts the new segment if the basin reported for it matches the
      target basin for the shooting direction;
    * on acceptance, splices the segment into the current path.

    Files written (relative to the working directory):
        TPE.log                             -- one summary line per trial
        sp/config-<i>.npz, sp/<i>.dat       -- shooting-point data per trial
        result/TPE-<i>-<accept|reject>.npz  -- the full new trajectory
        pe-only/TPE-<i>-<accept|reject>.npz -- new trajectory up to ``ends``
        current_path.npz                    -- the current spliced path
    """

    # Run parameters. steps, timestep, damp and temp are forwarded to
    # run_oneway_commitor; temp is also the target of the rescaling below.
    initial_config = "initial.npz"
    replica = 200  # number of shooting trials
    steps = 4000
    timestep = 0.5
    damp = 50
    temp = 50
    # NOTE(review): natom is not used below (n = len(ele_type) is used instead).
    natom = 22
    nswap = 10  # velocity swaps per perturbation (see perturb_v)
    file_title = "TPE"
    # Line-buffered log (buffering=1).
    # NOTE(review): never explicitly closed.
    log = open(f"{file_title}.log", "w", 1)
    kB = 8.6173303e-5 # eV/K

    mkdir()

    # initialization
    start = time.time()

    data = np.load(initial_config)
    all_pos = data['pos']
    all_vel = data['vel']
    all_pe = data['pe']
    all_alle = data['alle']
    all_intc = data['intc']
    all_label = data['label']

    # Split the initial path at its highest-PE frame and identify the basin
    # at each end. The first half is flipped so that both searches start
    # from the highest-PE frame and move outwards.
    config = np.argmax(all_pe)
    basin0, ends0 = find_basin(np.flip(all_intc[:config, :], axis=0))
    basin1, ends1 = find_basin(all_intc[config:, :])

    # all_pos = all_pos[config-ends0:config+ends1]
    # all_vel = all_vel[config-ends0:config+ends1]
    # all_pe = all_pe[config-ends0:config+ends1]
    # all_alle = all_alle[config-ends0:config+ends1]
    # all_intc = all_intc[config-ends0:config+ends1]
    # all_label = all_label[config-ends0:config+ends1]

    # A basin of -1 (presumably "no basin found" from find_basin) is
    # treated as basin 0.
    if (basin0==-1):
        basin0 = 0
    if (basin1 == -1):
        basin1 = 0

    # Target basins: allbasin[0] for backward shots, allbasin[1] for forward.
    # If one end is basin 2 and the other is basin 0, the 0 end is relabelled
    # as basin 1. NOTE(review): the meaning of ids 0/1/2 comes from
    # find_basin; confirm there.
    allbasin = [basin0, basin1]
    if (2 in allbasin):
        if (0 == basin0):
            allbasin[0] = 1
        elif (0 == basin1):
            allbasin[1] = 1

    # Per-type masses (values look like H, C, C, O, N, H, H in amu; confirm).
    # ele_type holds the 1-based type index of each of the 22 atoms.
    type_m = [1.008, 12.01, 12.01, 16.0 , 14.01, 1.008, 1.008]
    ele_type = [1, 2, 1, 1, 3, 4, 5, 6, 2, 7, 2, 1, 1, 1, 3, 4, 5, 6,
            2, 7, 7, 7]
    ele_m = np.array([type_m[i-1] for i in ele_type])
    # NOTE(review): vel_rescale is not used anywhere below.
    vel_rescale = sqrt(3*kB*temp/ele_m*1.6e-2/1.66)

    # All atom pairs (i < j) with equal mass. Exchanging their velocity
    # vectors is the perturbation applied by perturb_v.
    n = len(ele_type)
    swap_group = []
    for i in range(n):
        for j in range(i+1, n):
            if (type_m[ele_type[i]-1] == type_m[ele_type[j]-1] ):
                swap_group += [[i, j]]
    swap_group = np.array(swap_group)
    nswap_group = len(swap_group)


    # Kinetic temperature of the highest-PE frame, averaged over atoms.
    # The 1.6e-2 / 1.66 factors presumably convert amu*(A/fs)^2 to eV;
    # confirm against the MD units.
    # NOTE(review): config recomputes the same value as above. The rescaled
    # `vel` is never used afterwards; shooting uses the unscaled all_vel.
    config = np.argmax(all_pe)
    old_vel = np.copy(data['vel'][config])
    old_temp = np.average(ele_m*norm(old_vel.reshape([-1, 3]), axis=1)**2/3.0/kB/1.6e-2*1.66)
    ratio = sqrt(temp/old_temp)
    vel = data['vel']*ratio

    # prepare the pdf
    # Discrete distribution over shooting offsets t in [-tmax, tmax] with
    # p(t) proportional to exp(k*t). int_pt_list[i] = sum(p[:i]) is the
    # exclusive cumulative sum that find_sp samples from.
    tmax = 25
    k = 0.01
    t_list = np.arange(-tmax, tmax+1, 1)
    pt_list = np.exp(k*t_list)
    pt_list = pt_list / np.sum(pt_list)
    int_pt_list = np.array([np.sum(pt_list[:i]) for i in range(pt_list.shape[0])])


    n = len(ele_type)

    accept = 0
    reject = 0

    # Labels indexed by xdir (0 = backward, 1 = forward).
    direction = ["backward", "forward"]

    for instance in range(replica):

        # Until the first acceptance, always shoot around the highest-PE frame.
        if (accept == 0):
            original_sp = np.argmax(all_pe)
        nconfig = len(all_pos)
        xdir, config = find_sp(int_pt_list, t_list, original_sp, nconfig)

        # (xdir-0.5)*2 is -1 for backward (xdir=0) and +1 for forward, so
        # backward shots start from reversed velocities.
        x0 = all_pos[config, :]
        velocities = (xdir-0.5)*2*all_vel[config]

        # After the first acceptance, also swap velocities of equal-mass atoms.
        if (accept != 0):
            velocities = perturb_v(velocities, nswap, swap_group, nswap_group)

        intc0 = all_intc[config]
        pe0 = all_pe[config]

        # The call is expected to leave "hello.dat" behind, which is moved
        # into sp/ below.
        newxyz, newvel, alle, intc, basin, ends = run_oneway_commitor(instance*2, "hello", x0, velocities, temp, steps, timestep, damp)

        np.savez(f"sp/config-{instance}.npz",
                pos = x0, vel=velocities, col=intc0, n=basin, pe=pe0)
        os.rename("hello.dat", f"sp/{instance}.dat")

        # Length of the old path on the side that a new segment would replace
        # (only used in the printout and the disabled cond2 below).
        if (xdir == 1):
            oldlength = len(all_pe)-config
        else:
            oldlength = config


        # Accept only if the new segment ended in this direction's target basin.
        cond1 = (basin == allbasin[xdir])
        # ratio = np.random.random()
        # cond2 = ( ratio < oldlength/float(ends))
        cond2 = True

        prints = f"{instance:5d} {direction[xdir]:8}"\
                f" trial pt{config:5d} from {nconfig:5d}"\
                f" Temp {np.min(alle[:, 0]):4.0f}"\
                f" to {np.max(alle[:, 0]):4.0f}"\
                f" max_pe {np.max(alle[:, 1]):8.2f},"\
                f" Et {alle[0, 3]:6.2f} to {alle[-1, 3]:6.2f}"\
                f" (old {oldlength} new {ends:5d})"

        # Label every new frame with `basin`. Frames after `ends` are
        # relabelled with whatever basin find_basin reports for that tail,
        # if it differs.
        label = np.zeros(newxyz.shape[0])
        label += basin
        if (basin!=-1):
            end_basin, new_ends = find_basin(intc[ends+1:])
            if (end_basin!=basin):
                label[ends+1:] += (end_basin-basin)

        if (cond1 and cond2):
            accept += 1
            result = "accept"
            if (xdir == 1):
                # Forward: keep old frames 0..config, append new frames 0..ends.
                all_pos = np.vstack([all_pos[:config+1], newxyz[:ends+1]])
                all_vel = np.vstack([all_vel[:config+1], newvel[:ends+1]])
                all_intc = np.vstack([all_intc[:config+1], intc[:ends+1]])
                all_alle = np.vstack([all_alle[:config+1], alle[:ends+1]])
                all_pe = np.hstack([all_pe[:config+1], alle[:ends+1, 1]])
                all_label = np.hstack([all_label[:config+1], label[:ends+1]])
                original_sp = config
            else:
                # Backward: reverse new frames 0..ends (negating velocities)
                # and prepend them to old frames config..end. The new
                # shooting point, index `ends`, is newxyz[0] after the flip.
                all_pos = np.vstack([np.flip(newxyz[:ends+1], axis=0),
                    all_pos[config:]])
                all_vel = np.vstack([-np.flip(newvel[:ends+1], axis=0),
                    all_vel[config:]])
                all_intc = np.vstack([np.flip(intc[:ends+1], axis=0),
                    all_intc[config:]])
                all_alle = np.vstack([np.flip(alle[:ends+1], axis=0),
                    all_alle[config:]])
                all_pe = np.hstack([np.flip(alle[:ends+1, 1], axis=0),
                    all_pe[config:]])
                all_label = np.hstack([np.flip(label[:ends+1], axis=0),
                    all_label[config:]])
                original_sp = ends
        else:
            result = "reject"

        prints += f" {result:10} target {allbasin[xdir]:2} actual {basin:2d}"
        print(prints)
        print(prints, file=log)

        # Save the full new trajectory, its part up to `ends`, and the
        # current spliced path.
        np.savez(f"result/{file_title}-{instance}-{result}.npz",
                pos = newxyz, vel=newvel, label=label, intc=intc, alle=alle, pe=alle[:, 1])
        np.savez(f"pe-only/{file_title}-{instance}-{result}.npz",
                pos = newxyz[:ends+1], vel=newvel[:ends+1],
                label=label[:ends+1], intc=intc[:ends+1],
                alle=alle[:ends+1],
                pe=alle[:ends+1, 1])
        np.savez(f"current_path.npz",
                pos = all_pos, vel=all_vel,
                label=all_label, intc=all_intc,
                alle=all_alle, pe=all_pe)
        # Drop references to this trial's trajectory arrays before the next run.
        del newxyz
        del newvel
        del intc
        del alle
        del label

    print(f"acceptance rate {accept/float(replica)}")
    print(f"acceptance rate {accept/float(replica)}", file=log)

    end = time.time()
    print(f"total time: {end-start}")
    print(f"total time: {end-start}", file=log)


def mkdir(dirs=("result", "sp", "pe-only")):
    """Create the output directories used by main(), if they do not exist yet.

    main() writes per-trial files into ``result/``, ``sp/`` and ``pe-only/``.

    Args:
        dirs: Directory paths to create, relative to the current working
            directory. Defaults to the three directories main() writes to.

    Raises:
        OSError: if a directory cannot be created, for example because
            permission is denied or a regular file occupies the path. The
            previous bare ``except: pass`` hid such errors (and even
            KeyboardInterrupt) until the first ``np.savez`` into the
            missing directory failed.
    """
    for name in dirs:
        # exist_ok=True tolerates an already-existing directory (re-running
        # in the same place) without hiding any other failure.
        os.makedirs(name, exist_ok=True)

def find_sp(int_pt_list, t_list, original_sp, nconfig):
    """Pick a shooting direction and a new shooting frame near ``original_sp``.

    Each attempt works as follows:

    * A direction is chosen with equal probability: forward gives xdir=1
      and sign +1; backward gives xdir=0 and sign -1.
    * An offset is drawn from the discrete distribution over ``t_list``
      whose exclusive cumulative sums are ``int_pt_list``
      (``int_pt_list[i] == sum(p[:i])``).
    * The offset is multiplied by the direction sign and added to
      ``original_sp``.

    Attempts that land outside ``[0, nconfig)`` are redrawn.

    Args:
        int_pt_list: Non-decreasing exclusive cumulative probabilities, with
            the same length as ``t_list`` and ``int_pt_list[0] == 0``.
        t_list: Candidate integer offsets, one per bin.
        original_sp: Frame index of the current shooting point.
        nconfig: Number of frames on the current path.

    Returns:
        ``(xdir, config)``, where ``xdir`` is 1 for forward and 0 for
        backward, and ``config`` is the frame index of the new shooting point.
    """
    int_pt_list = np.asarray(int_pt_list)
    while True:
        # r1 in [-1, 1): strictly positive -> forward, otherwise backward.
        r1 = np.random.rand() * 2 - 1
        if r1 > 0:
            s, xdir = 1, 1
        else:
            s, xdir = -1, 0

        # Inverse-CDF sampling: bin i is chosen when
        # int_pt_list[i] <= r2 < int_pt_list[i+1]. The old nearest-point
        # lookup took t_list[idk+1] (instead of t_list[idk-1]) whenever r2
        # lay just below the nearest cumulative value, skewing the sampled
        # offsets towards the end of t_list.
        r2 = np.random.rand()
        idx = max(int(np.searchsorted(int_pt_list, r2, side="right")) - 1, 0)
        tao = s * t_list[idx]

        config = original_sp + tao
        if 0 <= config < nconfig:
            return xdir, config

def perturb_v(vel, nswap, swap_group, n):
    """Return a copy of ``vel`` with ``nswap`` random atom-pair velocity swaps.

    ``nswap`` distinct rows of ``swap_group`` are drawn without replacement.
    For each drawn row ``(a, b)``, the 3-component velocity vectors of atoms
    ``a`` and ``b`` are exchanged.

    Args:
        vel: Velocity array whose total size is a multiple of 3
            (3 components per atom).
        nswap: Number of pairs to swap.
        swap_group: Sequence of ``(a, b)`` atom-index pairs.
        n: Number of leading rows of ``swap_group`` to draw from.

    Returns:
        A new flattened 1-D array. The input ``vel`` is left unmodified.

    Raises:
        ValueError: if ``nswap > n`` (raised by ``random.sample``).
    """
    # np.array copies the data. The previous reshape() returned a view, so
    # the swaps silently modified the caller's array.
    v = np.array(vel).reshape([-1, 3])
    for g in random.sample(range(n), nswap):
        a, b = swap_group[g]
        v[[a, b]] = v[[b, a]]
    return v.reshape([-1])




# Run the sampling only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()


